Skip to content

Commit 5152488

Browse files
authored
feat(spanner): add interval type support (#12009)
* feat(spanner): add interval type support * support , in decimal part of interval
1 parent 8e56f74 commit 5152488

File tree

4 files changed

+1448
-0
lines changed

4 files changed

+1448
-0
lines changed

spanner/integration_test.go

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ const (
7070
backupDDLStatements = "BACKUP_DDL_STATEMENTS"
7171
testTableDDLStatements = "TEST_TABLE_DDL_STATEMENTS"
7272
fkdcDDLStatements = "FKDC_DDL_STATEMENTS"
73+
intervalDDLStatements = "INTERVAL_DDL_STATEMENTS"
7374
testTableBitReversedSeqStatements = "TEST_TABLE_BIT_REVERSED_SEQUENCE_STATEMENTS"
7475
)
7576

@@ -325,6 +326,25 @@ var (
325326
)`,
326327
}
327328

329+
intervalDBStatements = []string{
330+
`CREATE TABLE IntervalTable (
331+
key STRING(MAX),
332+
create_time TIMESTAMP,
333+
expiry_time TIMESTAMP,
334+
expiry_within_month bool AS (expiry_time - create_time < INTERVAL 30 DAY),
335+
interval_array_len INT64 AS (ARRAY_LENGTH(ARRAY<INTERVAL>[INTERVAL '1-2 3 4:5:6' YEAR TO SECOND]))
336+
) PRIMARY KEY (key)`,
337+
}
338+
intervalDBPGStatements = []string{
339+
`CREATE TABLE IntervalTable (
340+
key text primary key,
341+
create_time timestamptz,
342+
expiry_time timestamptz,
343+
expiry_within_month bool GENERATED ALWAYS AS (expiry_time - create_time < INTERVAL '30' DAY) STORED,
344+
interval_array_len bigint GENERATED ALWAYS AS (ARRAY_LENGTH(ARRAY[INTERVAL '1-2 3 4:5:6'], 1)) STORED
345+
)`,
346+
}
347+
328348
statements = map[adminpb.DatabaseDialect]map[string][]string{
329349
adminpb.DatabaseDialect_GOOGLE_STANDARD_SQL: {
330350
singerDDLStatements: singerDBStatements,
@@ -334,6 +354,7 @@ var (
334354
testTableDDLStatements: readDBStatements,
335355
fkdcDDLStatements: fkdcDBStatements,
336356
testTableBitReversedSeqStatements: bitReverseSeqDBStatments,
357+
intervalDDLStatements: intervalDBStatements,
337358
},
338359
adminpb.DatabaseDialect_POSTGRESQL: {
339360
singerDDLStatements: singerDBPGStatements,
@@ -343,6 +364,7 @@ var (
343364
testTableDDLStatements: readDBPGStatements,
344365
fkdcDDLStatements: fkdcDBPGStatements,
345366
testTableBitReversedSeqStatements: bitReverseSeqDBPGStatments,
367+
intervalDDLStatements: intervalDBPGStatements,
346368
},
347369
}
348370

@@ -858,6 +880,195 @@ func TestIntegration_SingleUse_WithQueryOptions(t *testing.T) {
858880
}
859881
}
860882

883+
func TestIntegration_Interval(t *testing.T) {
884+
skipEmulatorTest(t)
885+
t.Parallel()
886+
887+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
888+
defer cancel()
889+
client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][intervalDDLStatements])
890+
defer cleanup()
891+
892+
m := InsertOrUpdate("IntervalTable", []string{"key", "create_time", "expiry_time"},
893+
[]interface{}{
894+
"test1",
895+
time.Date(2004, 11, 30, 4, 53, 54, 0, time.UTC),
896+
time.Date(2004, 12, 15, 4, 53, 54, 0, time.UTC),
897+
})
898+
_, err := client.Apply(ctx, []*Mutation{m})
899+
if err != nil {
900+
t.Fatal(err)
901+
}
902+
placeHolder := "@p1"
903+
timestampQuery := `TIMESTAMP('2004-11-30T10:23:54+0530')`
904+
if testDialect == adminpb.DatabaseDialect_POSTGRESQL {
905+
placeHolder = "$1"
906+
timestampQuery = `TIMESTAMPTZ '2004-11-30T10:23:54+0530'`
907+
}
908+
stmt := Statement{
909+
SQL: `SELECT expiry_within_month, interval_array_len FROM IntervalTable WHERE key = ` + placeHolder,
910+
Params: map[string]interface{}{
911+
"p1": "test1",
912+
},
913+
}
914+
iter := client.Single().Query(ctx, stmt)
915+
defer iter.Stop()
916+
917+
row, err := iter.Next()
918+
if err != nil {
919+
t.Fatal(err)
920+
}
921+
922+
var expiryWithinMonth bool
923+
var intervalArrayLen int64
924+
if err := row.Columns(&expiryWithinMonth, &intervalArrayLen); err != nil {
925+
t.Fatal(err)
926+
}
927+
928+
if !expiryWithinMonth {
929+
t.Error("expected expiry_within_month to be true")
930+
}
931+
if intervalArrayLen != 1 {
932+
t.Errorf("expected interval_array_len to be 1, got %d", intervalArrayLen)
933+
}
934+
935+
stmt = Statement{SQL: "SELECT INTERVAL '1' DAY + INTERVAL '1' MONTH AS Col1"}
936+
iter = client.Single().Query(ctx, stmt)
937+
defer iter.Stop()
938+
939+
row, err = iter.Next()
940+
if err != nil {
941+
t.Fatal(err)
942+
}
943+
944+
var interval Interval
945+
if err := row.Column(0, &interval); err != nil {
946+
t.Fatal(err)
947+
}
948+
949+
wantInterval := Interval{
950+
Months: 1,
951+
Days: 1,
952+
Nanos: big.NewInt(0),
953+
}
954+
955+
if interval.Months != wantInterval.Months || interval.Days != wantInterval.Days || interval.Nanos.Cmp(wantInterval.Nanos) != 0 {
956+
t.Errorf("got interval %+v, want %+v", interval, wantInterval)
957+
}
958+
959+
m = InsertOrUpdate("IntervalTable", []string{"key", "create_time", "expiry_time"},
960+
[]interface{}{
961+
"test2",
962+
time.Date(2004, 8, 30, 4, 53, 54, 0, time.UTC),
963+
time.Date(2004, 12, 15, 4, 53, 54, 0, time.UTC),
964+
})
965+
_, err = client.Apply(ctx, []*Mutation{m})
966+
if err != nil {
967+
t.Fatal(err)
968+
}
969+
970+
stmt = Statement{
971+
SQL: `SELECT COUNT(*) FROM IntervalTable WHERE create_time < ` + timestampQuery + ` - ` + placeHolder,
972+
Params: map[string]interface{}{
973+
"p1": Interval{Days: 30},
974+
},
975+
}
976+
iter = client.Single().Query(ctx, stmt)
977+
defer iter.Stop()
978+
979+
row, err = iter.Next()
980+
if err != nil {
981+
t.Fatal(err)
982+
}
983+
984+
var count int64
985+
if err := row.Column(0, &count); err != nil {
986+
t.Fatal(err)
987+
}
988+
989+
if count != 1 {
990+
t.Errorf("got count %d, want 1", count)
991+
}
992+
993+
stmt = Statement{
994+
SQL: "SELECT " + placeHolder,
995+
Params: map[string]interface{}{
996+
"p1": []Interval{
997+
{Months: 14, Days: 3, Nanos: big.NewInt(14706000000000)},
998+
{},
999+
{Months: -14, Days: -3, Nanos: big.NewInt(-14706000000000)},
1000+
{},
1001+
},
1002+
},
1003+
}
1004+
iter = client.Single().Query(ctx, stmt)
1005+
defer iter.Stop()
1006+
1007+
row, err = iter.Next()
1008+
if err != nil {
1009+
t.Fatal(err)
1010+
}
1011+
1012+
var intervals []NullInterval
1013+
if err := row.Column(0, &intervals); err != nil {
1014+
t.Fatal(err)
1015+
}
1016+
1017+
wantIntervals := []NullInterval{
1018+
{Interval: Interval{Months: 14, Days: 3, Nanos: big.NewInt(14706000000000)}, Valid: true},
1019+
{Valid: true},
1020+
{Interval: Interval{Months: -14, Days: -3, Nanos: big.NewInt(-14706000000000)}, Valid: true},
1021+
{Valid: true},
1022+
}
1023+
1024+
if len(intervals) != len(wantIntervals) {
1025+
t.Fatalf("got %d intervals, want %d", len(intervals), len(wantIntervals))
1026+
}
1027+
1028+
for i := range intervals {
1029+
if intervals[i].Valid != wantIntervals[i].Valid || intervals[i].Interval.Months != wantIntervals[i].Interval.Months ||
1030+
intervals[i].Interval.Days != wantIntervals[i].Interval.Days ||
1031+
(intervals[i].Interval.Nanos != nil && wantIntervals[i].Interval.Nanos != nil && intervals[i].Interval.Nanos.Cmp(wantIntervals[i].Interval.Nanos) != 0) {
1032+
t.Errorf("interval %d: got %+v, want %+v", i, intervals[i], wantIntervals[i])
1033+
}
1034+
}
1035+
1036+
stmt = Statement{
1037+
SQL: `SELECT ARRAY[CAST('P1Y2M3DT4H5M6.789123S' AS INTERVAL),
1038+
null,
1039+
CAST('P-1Y-2M-3DT-4H-5M-6.789123S' AS INTERVAL)] AS Col1`,
1040+
}
1041+
iter = client.Single().Query(ctx, stmt)
1042+
defer iter.Stop()
1043+
1044+
row, err = iter.Next()
1045+
if err != nil {
1046+
t.Fatal(err)
1047+
}
1048+
1049+
if err := row.Column(0, &intervals); err != nil {
1050+
t.Fatal(err)
1051+
}
1052+
1053+
wantIntervals = []NullInterval{
1054+
{Interval: Interval{Months: 14, Days: 3, Nanos: big.NewInt(14706789123000)}, Valid: true},
1055+
{Valid: false},
1056+
{Interval: Interval{Months: -14, Days: -3, Nanos: big.NewInt(-14706789123000)}, Valid: true},
1057+
}
1058+
1059+
if len(intervals) != len(wantIntervals) {
1060+
t.Fatalf("got %d intervals, want %d", len(intervals), len(wantIntervals))
1061+
}
1062+
1063+
for i := range intervals {
1064+
if intervals[i].Valid != wantIntervals[i].Valid || intervals[i].Interval.Months != wantIntervals[i].Interval.Months ||
1065+
intervals[i].Interval.Days != wantIntervals[i].Interval.Days ||
1066+
(intervals[i].Interval.Nanos != nil && wantIntervals[i].Interval.Nanos != nil && intervals[i].Interval.Nanos.Cmp(wantIntervals[i].Interval.Nanos) != 0) {
1067+
t.Errorf("interval %d: got %+v, want %+v", i, intervals[i], wantIntervals[i])
1068+
}
1069+
}
1070+
}
1071+
8611072
func TestIntegration_TransactionWasStartedInDifferentSession(t *testing.T) {
8621073
t.Parallel()
8631074
// TODO: unskip once https://b.corp.google.com/issues/309745482 is fixed

spanner/protoutils.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ func dateType() *sppb.Type {
123123
return &sppb.Type{Code: sppb.TypeCode_DATE}
124124
}
125125

126+
func intervalType() *sppb.Type {
127+
return &sppb.Type{Code: sppb.TypeCode_INTERVAL}
128+
}
129+
126130
func listProto(p ...*proto3.Value) *proto3.Value {
127131
return &proto3.Value{Kind: &proto3.Value_ListValue{ListValue: &proto3.ListValue{Values: p}}}
128132
}

0 commit comments

Comments
 (0)